{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/jwen/Projects/road_building_extraction\n"
     ]
    }
   ],
   "source": [
    "cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/jwen/Projects/road_building_extraction/src\n"
     ]
    }
   ],
   "source": [
    "cd src"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import warnings\n",
    "warnings.simplefilter(\"ignore\", (UserWarning, FutureWarning))\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from torch.utils.data import DataLoader\n",
    "from skimage import io\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "from torchvision import transforms\n",
    "\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "import glob\n",
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "\n",
    "from utils import data_utils\n",
    "from utils import augmentation as aug\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data directory structure\n",
    "The raw data has been downloaded, but its nested directory layout is awkward to work with during training and testing. To make loading easier, we build one .csv per dataset that lists every image's file path together with the subdirectories it sits in (the train/valid/test split and the sat/map folder)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one csv for each of the main datasets\n",
    "DATA_DIR = '/Users/jwen/Projects/road_building_extraction/data'\n",
    "\n",
    "def find_tiffs(dataset_name):\n",
    "    \"\"\"Return the sorted paths of every .tiff file found below DATA_DIR/<dataset_name>.\"\"\"\n",
    "    pattern = '{}/{}/**/*.tiff'.format(DATA_DIR, dataset_name)\n",
    "    # glob returns paths in arbitrary filesystem order; sort so the csv rows are reproducible\n",
    "    return sorted(glob.glob(pattern, recursive=True))\n",
    "\n",
    "mass_buildings = find_tiffs('mass_buildings')\n",
    "mass_roads = find_tiffs('mass_roads')\n",
    "mass_roads_crop = find_tiffs('mass_roads_crop')\n",
    "\n",
    "# naming the column at construction keeps 'file_path' present even when no files are found\n",
    "mass_buildings_df = pd.DataFrame(mass_buildings, columns=['file_path'])\n",
    "mass_roads_df = pd.DataFrame(mass_roads, columns=['file_path'])\n",
    "mass_roads_crop_df = pd.DataFrame(mass_roads_crop, columns=['file_path'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "# derive the file-name and directory columns from each file path\n",
    "column_extractors = (\n",
    "    # file name of the satellite image\n",
    "    ('sat_img_path', lambda path: path.split('/')[-1]),\n",
    "    # same file name with its last character dropped\n",
    "    ('map_img_path', lambda path: path.split('/')[-1][:-1]),\n",
    "    # directory that directly contains the file\n",
    "    ('sat_map', lambda path: path.split('/')[-2]),\n",
    "    # directory one level above that\n",
    "    ('train_valid_test', lambda path: path.split('/')[-3]),\n",
    ")\n",
    "\n",
    "for paths_df in (mass_buildings_df, mass_roads_df, mass_roads_crop_df):\n",
    "    for column_name, extract in column_extractors:\n",
    "        paths_df[column_name] = paths_df['file_path'].apply(extract)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create csv file with the files paths to the data\n",
    "# mass_buildings_df.to_csv('/Users/jwen/Projects/road_building_extraction/data/mass_buildings/mass_buildings.csv', index=False)\n",
    "# mass_roads_df.to_csv('/Users/jwen/Projects/road_building_extraction/data/mass_roads/mass_roads.csv', index=False)\n",
    "# mass_roads_crop_df.to_csv('/Users/jwen/Projects/road_building_extraction/data/mass_roads_crop/mass_roads_crop.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13861"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# total observations for new cropped dataset\n",
    "len(mass_roads_crop_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Testing Data Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load data with transformations\n",
    "# earlier variant, kept for reference: read the uncropped mass_roads csv and rescale/crop on the fly\n",
    "# mass_dataset_train = MassRoadBuildingDataset('/Users/jwen/Projects/road_building_extraction/data/mass_roads/mass_roads.csv','mass_roads','train',\n",
    "#                                        transform=transforms.Compose([RescaleTarget(268), RandomCropTarget(238), ToTensorTarget()]))\n",
    "\n",
    "\n",
    "# mass_dataset_val = MassRoadBuildingDataset('/Users/jwen/Projects/road_building_extraction/data/mass_roads/mass_roads.csv','mass_roads','valid',\n",
    "#                                        transform=transforms.Compose([RescaleTarget(268), ToTensorTarget()]))\n",
    "\n",
    "data_set = 'mass_roads_crop'\n",
    "data_path = '/Users/jwen/Projects/road_building_extraction/data/mass_roads_crop/mass_roads_crop.csv'\n",
    "\n",
    "# the training split goes through aug.ToTensorTarget; no transform argument is passed for the validation split\n",
    "train_transform = transforms.Compose([aug.ToTensorTarget()])\n",
    "mass_dataset_train = data_utils.MassRoadBuildingDataset(data_path, data_set, 'train', transform=train_transform)\n",
    "mass_dataset_val = data_utils.MassRoadBuildingDataset(data_path, data_set, 'valid')\n",
    "\n",
    "# other augmentations that can be added to the Compose list:\n",
    "# RandomRotationTarget(45,resize=True), NormalizeTarget(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build one loader per split, then pull a single batch from each as a smoke test\n",
    "dataloader = DataLoader(mass_dataset_train, batch_size=3, num_workers=4)\n",
    "dataloader_valid = DataLoader(mass_dataset_val, batch_size=6, num_workers=4)\n",
    "\n",
    "data_batch = next(iter(dataloader))\n",
    "data_batch_valid = next(iter(dataloader_valid))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Crop Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "def img_crop_coordinates(img, output_size):\n",
    "    \"\"\"Pick a random top-left corner for a square crop of `img`.\n",
    "\n",
    "    Args:\n",
    "        img: array whose first two dimensions are (rows, columns); any further\n",
    "            dimensions (e.g. channels) are ignored.\n",
    "        output_size: side length of the square crop, in pixels.\n",
    "\n",
    "    Returns:\n",
    "        (i, j): row and column offsets such that\n",
    "        img[i:i+output_size, j:j+output_size] lies fully inside the image.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the image is smaller than the requested crop.\n",
    "    \"\"\"\n",
    "    # array images are indexed (row, column), i.e. (height, width)\n",
    "    h, w = img.shape[:2]\n",
    "    th, tw = (output_size, output_size)\n",
    "    if h < th or w < tw:\n",
    "        raise ValueError('crop size {} exceeds image size {}x{}'.format(output_size, h, w))\n",
    "    if w == tw and h == th:\n",
    "        return 0, 0\n",
    "\n",
    "    # i is the row offset and j the column offset, matching img[i:..., j:...]\n",
    "    i = random.randint(0, h - th)\n",
    "    j = random.randint(0, w - tw)\n",
    "\n",
    "    return i, j\n",
    "\n",
    "def img_crop(csv_df, output_size, num_crops, min_road_fraction=0.03, max_tries=250):\n",
    "    \"\"\"Save `num_crops` random square crops of every satellite image and its map.\n",
    "\n",
    "    Rows whose 'sat_map' value is 'missing' are skipped. For each remaining row\n",
    "    the satellite image is read from 'file_path' and the map image from the\n",
    "    sibling 'map' directory. Matching crops of both are written to\n",
    "    <root>/mass_roads_crop/<train_valid_test>/sat/ and .../map/, where <root> is\n",
    "    'file_path' with its last four components removed, and each file name is\n",
    "    prefixed with '<crop index>_'. The output directories are not created here.\n",
    "\n",
    "    Args:\n",
    "        csv_df: DataFrame with the columns 'file_path', 'sat_img_path',\n",
    "            'map_img_path', 'sat_map' and 'train_valid_test'.\n",
    "        output_size: height and width of each crop, in pixels.\n",
    "        num_crops: number of crops written per image.\n",
    "        min_road_fraction: minimum value of sum(map crop) / (output_size**2 * 255)\n",
    "            that a road-rich crop should reach.\n",
    "        max_tries: how many times a road-rich crop is re-drawn before the last\n",
    "            draw is accepted anyway.\n",
    "    \"\"\"\n",
    "\n",
    "    # filter out the images with missing data\n",
    "    filtered_csv_df = csv_df[csv_df['sat_map']!='missing']\n",
    "\n",
    "    # crops with an index up to this value must reach min_road_fraction (deal with class imbalance)\n",
    "    last_road_rich_crop = int(0.6*num_crops)\n",
    "\n",
    "    for row in tqdm(filtered_csv_df.itertuples(), total=len(filtered_csv_df)):\n",
    "\n",
    "        # columns are read by name so the result does not depend on their order in the csv\n",
    "        path_parts = row.file_path.split('/')\n",
    "        sat_img = io.imread(row.file_path)\n",
    "        map_img = io.imread('/'.join(path_parts[:-2])+'/map/'+row.map_img_path)\n",
    "        out_dir = '/'.join(path_parts[:-4])+'/mass_roads_crop/'+row.train_valid_test\n",
    "\n",
    "        for num_crop in range(num_crops):\n",
    "\n",
    "            i, j = img_crop_coordinates(sat_img, output_size)\n",
    "            cropped_map_img = map_img[i:i+output_size, j:j+output_size]\n",
    "\n",
    "            if num_crop <= last_road_rich_crop:\n",
    "                # draw a new position on every retry, with a fresh retry budget for each crop\n",
    "                tries = 0\n",
    "                while (np.sum(cropped_map_img)/(output_size**2*255)<min_road_fraction) and tries<max_tries:\n",
    "                    i, j = img_crop_coordinates(sat_img, output_size)\n",
    "                    cropped_map_img = map_img[i:i+output_size, j:j+output_size]\n",
    "                    tries += 1\n",
    "\n",
    "            # slice the satellite image only once the final position is known\n",
    "            cropped_sat_img = sat_img[i:i+output_size, j:j+output_size]\n",
    "\n",
    "            crop_prefix = '{}_'.format(str(num_crop))\n",
    "            Image.fromarray(cropped_sat_img).save(out_dir+'/sat/'+crop_prefix+row.sat_img_path)\n",
    "            Image.fromarray(cropped_map_img).save(out_dir+'/map/'+crop_prefix+row.map_img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_csv_df_path = '/Users/jwen/Projects/road_building_extraction/data/mass_roads/mass_roads.csv'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_df = pd.read_csv(img_csv_df_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_crop(img_df, 256, 15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
